#include "getter_setter_node.hh"

/// Stores the node-type id that get_static_id() will report.
///
/// @param id node-type id to remember.
auto plato::GetterSetterNodeCreator::set_static_id(PlatoNodeID id) -> void {
  static_id_ = id;
}

/// @return the node-type id last stored via set_static_id().
auto plato::GetterSetterNodeCreator::get_static_id() -> PlatoNodeID {
  const auto id = static_id_;
  return id;
}

/// Stores the id of the variable the created node's VAR pin is bound to.
///
/// @param id variable id; create() later resolves it through the domain.
auto plato::GetterSetterNodeCreator::set_var_id(VarID id) -> void {
  var_id_ = id;
}

/// @return the variable id last stored via set_var_id().
auto plato::GetterSetterNodeCreator::get_var_id() -> VarID {
  const auto id = var_id_;
  return id;
}

// Nothing to set up eagerly: the output pin name is registered later through
// add_var_pin().
plato::GetterCreator::GetterCreator() = default;

// No resources beyond members to release.
plato::GetterCreator::~GetterCreator() = default;

/// @return the node-type id of Getter nodes (the value held by
///         get_static_id()).
auto plato::GetterCreator::static_id() -> PlatoNodeID {
  auto id = get_static_id();
  return id;
}

/// Getter nodes are built with no input pins (see create()), so every name
/// resolves to INVALID_PIN_INDEX.
auto plato::GetterCreator::get_pin_input_index(
    const std::string & /*pin_name*/) -> PlatoPinIndex {
  return INVALID_PIN_INDEX;
}

/// Resolves an output-pin name of a Getter node to its pin index.
///
/// @param pin_name name previously registered through add_var_pin().
/// @return the registered index, or INVALID_PIN_INDEX for unknown names.
auto plato::GetterCreator::get_pin_output_index(const std::string &pin_name)
    -> PlatoPinIndex {
  const auto found = pin_name_output_index_map_.find(pin_name);
  if (found != pin_name_output_index_map_.end()) {
    return found->second;
  }
  return INVALID_PIN_INDEX;
}

/// Node size plus one pin size per registered output-pin name.
///
/// NOTE(review): this counts registered names, not pins actually created.
/// create() always builds exactly one pin, so repeated add_var_pin() calls with
/// different names make this larger than get_fixed_memory_size(). Confirm
/// that is intended.
auto plato::GetterCreator::get_node_memory_size() -> std::size_t {
  const auto output_pin_count = pin_name_output_index_map_.size();
  return get_node_size() + get_pin_size() * output_pin_count;
}

/// Builds a Getter node: a node named "Getter" whose only pin is a VAR output
/// bound to the variable selected with add_var_pin().
///
/// @param domain_ptr domain handed to new_node()/new_pin(); also queried for
///        the bound variable via get(get_var_id()).
/// @param id         instance id given to the new node.
/// @return the newly created node.
auto plato::GetterCreator::create(DomainPtr domain_ptr, PlatoNodeID id)
    -> PlatoNodePtr {
  // NOTE(review): sync_type_ is only forwarded to new_node(). The bound
  // variable is never given a SENDER/RECEIVER sync role here. Confirm whether
  // SERVER_SIDE/CLIENT_SIDE getters need that wired up on the variable.
  auto node_ptr = new_node(domain_ptr, get_static_id(), id, sync_type_);
  node_ptr->set_name("Getter");
  // The single output pin. add_var_pin() registers its name at index 0.
  auto var_pin =
      new_pin(domain_ptr, PlatoPinType::VAR, domain_ptr->get(get_var_id()));
  node_ptr->add_output(var_pin);
  return node_ptr;
}

/// Registers `name` as the Getter's output pin and binds the node to `var_id`.
///
/// NOTE(review): every registered name maps to output index 0, and an
/// already-present name keeps its entry (emplace does not overwrite). Each
/// call with a new name therefore adds an alias for the same pin.
auto plato::GetterCreator::add_var_pin(const std::string &name, VarID var_id)
    -> void {
  constexpr int var_output_index = 0;
  pin_name_output_index_map_.emplace(name, var_output_index);
  set_var_id(var_id);
}

auto plato::GetterCreator::get_fixed_memory_size() -> std::size_t {
  return get_node_size() + get_pin_size();
}

/// Registers the fixed execution pins every Setter has. "Done" is output 0 and
/// "Do" is input 0. The variable pin (input 1) is added by add_var_pin().
plato::SetterCreator::SetterCreator() {
  pin_name_output_index_map_.emplace("Done", 0);
  pin_name_input_index_map_.emplace("Do", 0);
}

// No resources beyond members to release.
plato::SetterCreator::~SetterCreator() = default;

/// @return the node-type id of Setter nodes (the value held by
///         get_static_id()).
auto plato::SetterCreator::static_id() -> PlatoNodeID {
  auto id = get_static_id();
  return id;
}

/// Resolves an input-pin name of a Setter node to its pin index.
///
/// @param pin_name "Do" (registered by the constructor) or the variable pin
///        name registered through add_var_pin().
/// @return the registered index, or INVALID_PIN_INDEX for unknown names.
auto plato::SetterCreator::get_pin_input_index(const std::string &pin_name)
    -> PlatoPinIndex {
  const auto found = pin_name_input_index_map_.find(pin_name);
  if (found != pin_name_input_index_map_.end()) {
    return found->second;
  }
  return INVALID_PIN_INDEX;
}

/// Resolves an output-pin name of a Setter node to its pin index.
///
/// @param pin_name output pin name ("Done" is registered by the constructor).
/// @return the registered index, or INVALID_PIN_INDEX for unknown names.
auto plato::SetterCreator::get_pin_output_index(const std::string &pin_name)
    -> PlatoPinIndex {
  const auto found = pin_name_output_index_map_.find(pin_name);
  if (found != pin_name_output_index_map_.end()) {
    return found->second;
  }
  return INVALID_PIN_INDEX;
}

/// Node size plus one pin size per registered input and output pin name.
auto plato::SetterCreator::get_node_memory_size() -> std::size_t {
  const auto pin_count =
      pin_name_input_index_map_.size() + pin_name_output_index_map_.size();
  return get_node_size() + get_pin_size() * pin_count;
}

/// Builds a Setter node named "Setter" with two inputs, the "Do" EXEC pin and
/// a VAR pin bound to the variable selected with add_var_pin(), and one
/// "Done" EXEC output.
///
/// @param domain_ptr domain handed to new_node()/new_pin(); also queried for
///        the bound variable via get(get_var_id()).
/// @param id         instance id given to the new node.
/// @return the newly created node.
auto plato::SetterCreator::create(DomainPtr domain_ptr, PlatoNodeID id)
    -> PlatoNodePtr {
  // NOTE(review): sync_type_ is only forwarded to new_node(). The bound
  // variable is never given a SENDER/RECEIVER sync role here. Confirm whether
  // SERVER_SIDE/CLIENT_SIDE setters need that wired up on the variable.
  auto node_ptr = new_node(domain_ptr, get_static_id(), id, sync_type_);
  node_ptr->set_name("Setter");
  // Inputs are added in the order the name maps expect: "Do" is registered
  // at 0 by the constructor and the variable pin at 1 by add_var_pin().
  auto do_pin = new_pin(domain_ptr, PlatoPinType::EXEC, nullptr);
  node_ptr->add_input(do_pin);
  auto var_pin =
      new_pin(domain_ptr, PlatoPinType::VAR, domain_ptr->get(get_var_id()));
  node_ptr->add_input(var_pin);
  // "Done" is the only output (registered at index 0 by the constructor).
  auto done_pin = new_pin(domain_ptr, PlatoPinType::EXEC, nullptr);
  node_ptr->add_output(done_pin);
  return node_ptr;
}

/// Registers `name` as the Setter's variable input pin (input index 1, after
/// "Do") and binds the node to `var_id`.
///
/// NOTE(review): emplace does not overwrite, so a `name` already present
/// (e.g. "Do") keeps its existing index. Confirm callers never pass such names.
auto plato::SetterCreator::add_var_pin(const std::string &name, VarID var_id)
    -> void {
  constexpr int var_input_index = 1;
  pin_name_input_index_map_.emplace(name, var_input_index);
  set_var_id(var_id);
}

/// Footprint of a Setter node with the three pins create() builds:
/// "Do", the variable pin, and "Done".
auto plato::SetterCreator::get_fixed_memory_size() -> std::size_t {
  constexpr int fixed_pin_count = 3;
  return get_node_size() + get_pin_size() * fixed_pin_count;
}
